
Bria 3.2 pipeline #12010


Merged

42 commits merged into huggingface:main on Aug 20, 2025

Conversation

galbria
Contributor

@galbria galbria commented Jul 29, 2025

What does this PR do?

Implementing the Bria 3.2 pipeline and `BriaTransformer2DModel`
issue: here
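
A minimal usage sketch of the new pipeline (prompt taken from the nightly test below):

import torch
from diffusers import BriaPipeline

pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = pipe(prompt="a close-up of a smiling cat, high quality, realistic").images[0]
image.save("bria.png")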

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

galbria added 6 commits July 28, 2025 08:37
- Introduced `BriaTransformer2DModel` and `BriaPipeline` for enhanced image generation capabilities.
- Updated import structures across various modules to include the new Bria components.
- Added utility functions and output classes specific to the Bria pipeline.
- Implemented tests for the Bria pipeline to ensure functionality and output integrity.
@SahilCarterr
Contributor

Error during model loading

Code

import torch
from diffusers import BriaPipeline
pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
pipe.to(device="cuda")
prompt = "A asian girl with red top and blue jeans"
negative_prompt = "Logo,Watermark,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"

images = pipe(prompt=prompt, negative_prompt=negative_prompt, height=1024, width=1024).images[0]
images

Logs

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipython-input-4-3624485676.py in <cell line: 0>()
----> 1 pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
      2 pipe.to(device="cuda")
      3 prompt = "A asian girl with red top and blue jeans"
      4 negative_prompt = "Logo,Watermark,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
      5 

3 frames
/content/diffusers/src/diffusers/pipelines/pipeline_utils.py in download(cls, pretrained_model_name, **kwargs)
   1536 
   1537             if load_components_from_hub and not trust_remote_code:
-> 1538                 raise ValueError(
   1539                     f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
   1540                     f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"

ValueError: The repository for briaai/BRIA-3.2 contains custom code in transformer/transformer_bria which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/briaai/BRIA-3.2/transformer/transformer_bria.py.
Please pass the argument `trust_remote_code=True` to allow custom code to be run.

FIX

To fix this, transformer_bria.py should be removed from the model repo on Hugging Face and model_index.json edited accordingly. Always relying on trust_remote_code=True is not efficient for model loading.
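
In the meantime, the error message itself points to a workaround, opting in to the custom code. A minimal sketch:

import torch
from diffusers import BriaPipeline

# Temporary workaround while the Hub repo still ships transformer_bria.py:
# explicitly allow the repo's custom code to run.
pipe = BriaPipeline.from_pretrained(
    "briaai/BRIA-3.2", torch_dtype=torch.bfloat16, trust_remote_code=True
)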

@galbria

@SahilCarterr
Contributor

Run `make style`, `make quality`, and `make fix-copies` in order to pass the checks on GitHub.
@galbria

@galbria
Contributor Author

galbria commented Jul 30, 2025

Hey @SahilCarterr, I fixed the Hugging Face model repo and ran the make targets.

…dND class for rotary position embedding, and enhance Timestep and TimestepProjEmbeddings classes. Add utility functions for handling negative prompts and generating original sigmas in pipeline_bria.py.
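
Background for reviewers: rotary position embedding rotates pairs of feature channels by position-dependent angles. A generic sketch (illustrative names, not the PR's actual EmbedND implementation):

import torch

def rope_angles(pos: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # One rotation frequency per pair of channels, decaying geometrically.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = pos.float().unsqueeze(-1) * freqs  # (..., dim // 2)
    # cos/sin terms of the 2x2 rotation applied to each channel pair.
    return torch.stack((angles.cos(), angles.sin()), dim=-1)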
@SahilCarterr
Contributor

Hey @galbria, check out the reviews above and edit accordingly.

@SahilCarterr
Contributor

Hey @galbria, you need to fix some tests in test_pipeline_bria.py. Below are the error logs.

Error Logs
=================================== FAILURES ===================================
________________ BriaPipelineSlowTests.test_bria_inference_bf16 ________________

self = <tests.pipelines.bria.test_pipeline_bria.BriaPipelineSlowTests testMethod=test_bria_inference_bf16>

    def test_bria_inference_bf16(self):
>       pipe = self.pipeline_class.from_pretrained(
            self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, tokenizer=None
        )

diffusers/tests/pipelines/bria/test_pipeline_bria.py:270: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py:114: in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
diffusers/src/diffusers/pipelines/pipeline_utils.py:1093: in from_pretrained
    model = pipeline_class(**init_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = BriaPipeline {
  "_class_name": "BriaPipeline",
  "_diffusers_version": "0.35.0.dev0",
  "feature_extractor": [
    nu...ansformer": [
    "diffusers",
    "BriaTransformer2DModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

transformer = BriaTransformer2DModel(
  (pos_embed): EmbedND()
  (time_embed): TimestepProjEmbeddings(
    (time_proj): Timesteps()
...((2304,), eps=1e-06, elementwise_affine=False)
  )
  (proj_out): Linear(in_features=2304, out_features=16, bias=True)
)
scheduler = FlowMatchEulerDiscreteScheduler {
  "_class_name": "FlowMatchEulerDiscreteScheduler",
  "_diffusers_version": "0.35.0....beta_sigmas": false,
  "use_dynamic_shifting": true,
  "use_exponential_sigmas": false,
  "use_karras_sigmas": false
}

vae = AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
... Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
  (post_quant_conv): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
)
text_encoder = None, tokenizer = None, image_encoder = None
feature_extractor = None

    def __init__(
        self,
        transformer: BriaTransformer2DModel,
        scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
    
        # TODO - why different than offical flux (-1)
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = 64  # due to patchify=> 128,128 => res of 1k,1k
    
        # T5 is senstive to precision so we use the precision used for precompute and cast as needed
>       self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
                            ^^^^^^^^^^^^^^^^^^^^
E       AttributeError: 'NoneType' object has no attribute 'to'

diffusers/src/diffusers/pipelines/bria/pipeline_bria.py:189: AttributeError
----------------------------- Captured stderr call -----------------------------
Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]Loaded vae as AutoencoderKL from `vae` subfolder of briaai/BRIA-3.2.
Loading pipeline components...:  33%|███▎      | 1/3 [00:00<00:00,  6.94it/s]Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of briaai/BRIA-3.2.
Loaded transformer as BriaTransformer2DModel from `transformer` subfolder of briaai/BRIA-3.2.
Loading pipeline components...: 100%|██████████| 3/3 [00:02<00:00,  1.30it/s]
_____________________ BriaPipelineSlowTests.test_to_dtype ______________________

self = <tests.pipelines.bria.test_pipeline_bria.BriaPipelineSlowTests testMethod=test_to_dtype>

    def test_to_dtype(self):
>       components = self.get_dummy_components()
                     ^^^^^^^^^^^^^^^^^^^^^^^^^
E       AttributeError: 'BriaPipelineSlowTests' object has no attribute 'get_dummy_components'

diffusers/tests/pipelines/bria/test_pipeline_bria.py:318: AttributeError
_________________ BriaPipelineNightlyTests.test_bria_inference _________________

self = <tests.pipelines.bria.test_pipeline_bria.BriaPipelineNightlyTests testMethod=test_bria_inference>

    def test_bria_inference(self):
        pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
        pipe.to(torch_device)
    
        prompt = "a close-up of a smiling cat, high quality, realistic"
        image = pipe(prompt=prompt, num_inference_steps=5, output_type="np").images[0]
    
>       image_slice = image[0, :10, :10, 0].flatten()
                      ^^^^^^^^^^^^^^^^^^^^^
E       IndexError: too many indices for array: array is 3-dimensional, but 4 were indexed

diffusers/tests/pipelines/bria/test_pipeline_bria.py:349: IndexError
----------------------------- Captured stdout call -----------------------------
Using dynamic shift in pipeline with sequence length 4096
----------------------------- Captured stderr call -----------------------------
Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
Loading checkpoint shards:  33%|███▎      | 1/3 [00:02<00:04,  2.01s/it]
Loading checkpoint shards:  67%|██████▋   | 2/3 [00:04<00:02,  2.01s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.55s/it]
Loaded text_encoder as T5EncoderModel from `text_encoder` subfolder of briaai/BRIA-3.2.
Loading pipeline components...:  20%|██        | 1/5 [00:06<00:24,  6.15s/it]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loaded tokenizer as T5TokenizerFast from `tokenizer` subfolder of briaai/BRIA-3.2.
Loading pipeline components...:  40%|████      | 2/5 [00:06<00:08,  2.70s/it]Loaded vae as AutoencoderKL from `vae` subfolder of briaai/BRIA-3.2.
Loading pipeline components...:  60%|██████    | 3/5 [00:06<00:03,  1.52s/it]Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of briaai/BRIA-3.2.
Loaded transformer as BriaTransformer2DModel from `transformer` subfolder of briaai/BRIA-3.2.
Loading pipeline components...: 100%|██████████| 5/5 [00:08<00:00,  1.72s/it]
100%|██████████| 5/5 [00:08<00:00,  1.77s/it]
=========================== short test summary info ============================
FAILED diffusers/tests/pipelines/bria/test_pipeline_bria.py::BriaPipelineSlowTests::test_bria_inference_bf16 - AttributeError: 'NoneType' object has no attribute 'to'
FAILED diffusers/tests/pipelines/bria/test_pipeline_bria.py::BriaPipelineSlowTests::test_to_dtype - AttributeError: 'BriaPipelineSlowTests' object has no attribute 'get_dummy_...
FAILED diffusers/tests/pipelines/bria/test_pipeline_bria.py::BriaPipelineNightlyTests::test_bria_inference - IndexError: too many indices for array: array is 3-dimensional, but 4 were ...
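
For reference, a minimal sketch of the fixes these logs point to (assuming the pipeline and test code shown in the tracebacks):

# In BriaPipeline.__init__: only cast the text encoder when one was actually
# loaded, since from_pretrained may be called with text_encoder=None.
if self.text_encoder is not None:
    self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)

# In test_bria_inference: images[0] is already a single (H, W, C) array, so
# slice it without a batch index.
image_slice = image[:10, :10, 0].flatten()

# test_to_dtype calls get_dummy_components, which BriaPipelineSlowTests does
# not define; it belongs with the fast tests (or the class needs that helper).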

@galbria
Contributor Author

galbria commented Aug 1, 2025

Hey @SahilCarterr, fixed the tests.

@linoytsaban linoytsaban requested review from asomoza and yiyixuxu August 7, 2025 07:28
@linoytsaban
Collaborator

@bot /style

Member

@a-r-r-o-w a-r-r-o-w left a comment


Thanks @galbria!

@galbria
Contributor Author

galbria commented Aug 15, 2025

@a-r-r-o-w @SahilCarterr sorry for the additional commit after you had already approved 😬
I've just fixed the doctree error.

And thank you guys for the help!

@galbria galbria requested a review from a-r-r-o-w August 15, 2025 03:12
@drcgrp

drcgrp commented Aug 15, 2025

Please add the LoRA training script from your examples repo here.
Also, the import in that repo references a function that doesn't exist in this pipeline at the moment. Please fix that too.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@asomoza
Member

asomoza commented Aug 18, 2025

I tested this model with group offloading so more users can try it without any quality loss. This is the code I used:

import torch

from diffusers import BriaPipeline


onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.float16)
# Stream the transformer's weights between CPU and GPU at leaf-module granularity.
pipe.transformer.enable_group_offload(
    onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level"
)
pipe.to("cuda")

# T5 is precision-sensitive: keep the encoder in float16, but run each block's
# final feed-forward output projection in float32.
pipe.text_encoder = pipe.text_encoder.to(dtype=torch.float16)
for block in pipe.text_encoder.encoder.block:
    block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)

# Run the VAE in float32 when its config defines no shift factor.
if pipe.vae.config.shift_factor == 0:
    pipe.vae.to(dtype=torch.float32)

prompt = "Photorealistic food photography of a stack of fluffy pancakes on a white plate, with maple syrup being poured over them. On top of the pancakes are the words 'BRIA 3.2' in bold, yellow, 3D letters. The background is dark and out of focus."
image = pipe(prompt).images[0]
image.save("bria.png")

So it uses the defaults from the pipeline; however, the image I get seems a lot worse than the one in the repo here. Is this expected @galbria?

The one generated with this code looks a lot more fake, like an old 3D render.

(image: bria.png)

Collaborator

@DN6 DN6 left a comment


Just one minor comment about the pipeline. Otherwise LGTM 👍🏽

@galbria
Contributor Author

galbria commented Aug 18, 2025

> I tested this model with group offloading so more users can try it without any quality loss. This is the code I used: […] The one generated with the code seems a lot more fake, like an old 3D render.

I used bfloat16 instead of float16 and got a better result. Does it help?
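
For reference, a sketch of that change to the snippet above (presumably the float16 casts switch to bfloat16 as well):

# Load in bfloat16 instead of float16.
pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16)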

@asomoza
Member

asomoza commented Aug 19, 2025

> I used bfloat16 instead of float16 and got a better result. Does it help?

@galbria I was just commenting on the difference with the example image; if you're good with the result and it's expected, it's your decision. My comment is definitely not a merge blocker.

@galbria
Contributor Author

galbria commented Aug 19, 2025

@SahilCarterr @asomoza @a-r-r-o-w @yiyixuxu What else should I do so we can merge it?

Member

@a-r-r-o-w a-r-r-o-w left a comment


Just one last question

@a-r-r-o-w a-r-r-o-w merged commit 7993be9 into huggingface:main Aug 20, 2025
11 checks passed